import torch
import torch.nn as nn
from ..utils import create_norm

class ResMLPDecoder(nn.Module):
    """MLP decoder whose output layer sees the concatenation of all hidden outputs.

    The decoder stacks ``num_layers - 1`` hidden blocks
    (``Linear -> PReLU -> Dropout -> create_norm(norm, hidden_dim)``). After
    every block a layer-normalised batch embedding is added to the activation,
    and the resulting activation is recorded. The recorded activations are
    concatenated along dim 1 and projected to ``out_dim`` by
    ``Linear -> PReLU``.

    Args:
        in_dim: Feature size of the input ``x_dict['h']`` (input width of the
            first hidden block).
        hidden_dim: Width of every hidden block and of the batch embedding.
        out_dim: Feature size of the reconstruction.
        num_layers: Total number of linear layers, counting the output layer;
            must be greater than 1.
        dropout: Dropout probability passed to ``nn.Dropout`` in each block.
        norm: Forwarded unchanged to ``create_norm``; the accepted values are
            defined by that helper (not visible here).
        batch_num: Number of rows in the batch embedding table.
    """

    def __init__(self, in_dim, hidden_dim, out_dim, num_layers, dropout, norm, batch_num):
        super().__init__()
        self.layers = nn.ModuleList()
        assert num_layers > 1, 'At least two layer for MLPs.'
        for i in range(num_layers - 1):
            # Only the first block consumes the raw input width.
            dim = hidden_dim if i > 0 else in_dim
            self.layers.append(nn.Sequential(
                nn.Linear(dim, hidden_dim),
                nn.PReLU(),
                nn.Dropout(dropout),
                create_norm(norm, hidden_dim)
            ))
        # The output layer consumes one hidden_dim-wide slice per hidden block.
        self.out_layer = nn.Sequential(
            nn.Linear(hidden_dim * (num_layers - 1), out_dim),
            nn.PReLU(),
        )
        self.batch_emb = nn.Embedding(batch_num, hidden_dim)
        self.layer_norm = nn.LayerNorm(hidden_dim)

    def forward(self, x_dict):
        """Decode ``x_dict['h']`` conditioned on ``x_dict['batch']``.

        Args:
            x_dict: Mapping with key ``'h'`` (input features) and key
                ``'batch'`` (indices looked up in the batch embedding table).
                Presumably ``'h'`` is ``(N, in_dim)`` and ``'batch'`` is
                ``(N,)`` -- confirm against the caller.

        Returns:
            A dict with ``'recon'`` (output of the final projection) and
            ``'latent'`` (``x_dict['h']`` returned unchanged).
        """
        batch_labels = x_dict['batch']
        x = x_dict['h']
        # The batch term depends on neither the layer nor the running
        # activation, so it is computed once and reused after every block.
        batch_term = self.layer_norm(self.batch_emb(batch_labels))
        hist = []
        for layer in self.layers:
            x = layer(x) + batch_term
            hist.append(x)
        return {'recon': self.out_layer(torch.cat(hist, 1)), 'latent': x_dict['h']}

class MLPDecoder(nn.Module):
    """Plain MLP decoder conditioned on a batch embedding added to its input.

    A layer-normalised batch embedding is added to the input once, the sum is
    passed through ``num_layers - 1`` hidden blocks
    (``Linear -> PReLU -> Dropout``), and a final ``Linear`` (optionally
    followed by ``out_act``) produces the reconstruction.

    Args:
        in_dim: Feature size of ``x_dict['h']`` and of the batch embedding.
        hidden_dim: Width of every hidden block.
        out_dim: Feature size of the reconstruction.
        num_layers: Total number of linear layers, counting the output layer;
            must be greater than 1.
        dropout: Dropout probability passed to ``nn.Dropout`` in each block.
        norm: Accepted but not used by this decoder (no normalisation layer is
            added to the hidden blocks).
        batch_num: Number of rows in the batch embedding table.
        out_act: Optional module appended after the output ``Linear``.
    """

    def __init__(self, in_dim, hidden_dim, out_dim, num_layers, dropout, norm, batch_num, out_act=None):
        super().__init__()
        assert num_layers > 1, 'At least two layer for MLPs.'

        # Input width per hidden block: only the first one sees in_dim.
        widths = [in_dim if idx == 0 else hidden_dim for idx in range(num_layers - 1)]
        self.layers = nn.ModuleList([
            nn.Sequential(
                nn.Linear(width, hidden_dim),
                nn.PReLU(),
                nn.Dropout(dropout),
            )
            for width in widths
        ])

        # Output head: linear projection plus the optional activation.
        head = [nn.Linear(hidden_dim, out_dim)]
        if out_act is not None:
            head.append(out_act)
        self.out_layer = nn.Sequential(*head)

        self.batch_emb = nn.Embedding(batch_num, in_dim)
        self.layer_norm = nn.LayerNorm(in_dim)

    def forward(self, x_dict):
        """Decode ``x_dict['h']`` conditioned on ``x_dict['batch']``.

        Args:
            x_dict: Mapping with key ``'h'`` (input features) and key
                ``'batch'`` (indices looked up in the batch embedding table).

        Returns:
            A dict with ``'recon'`` (output of the final layer) and
            ``'latent'`` (``x_dict['h']`` returned unchanged).
        """
        batch_ids = x_dict['batch']
        features = x_dict['h']
        # Out-of-place add, so the caller's tensor is left untouched.
        hidden = features + self.layer_norm(self.batch_emb(batch_ids))
        for block in self.layers:
            hidden = block(hidden)
        recon = self.out_layer(hidden)
        return {'recon': recon, 'latent': x_dict['h']}